// David Eberly, Geometric Tools, Redmond WA 98052
// Copyright (c) 1998-2019
// Distributed under the Boost Software License, Version 1.0.
// http://www.boost.org/LICENSE_1_0.txt
// http://www.geometrictools.com/License/Boost/LICENSE_1_0.txt
// File Version: 3.0.1 (2019/05/03)

#pragma once

#include <Physics/GteParticleSystem.h>
#include <cstring>
#include <set>

namespace gte
{

template <int N, typename Real>
class MassSpringArbitrary : public ParticleSystem<N, Real>
{
public:
    // A system of M particles joined by S springs whose connectivity is
    // arbitrary.  Construct it with the particle and spring counts, then
    // call SetSpring(...) once for every spring you want in the system.
    MassSpringArbitrary(int numParticles, int numSprings, Real step);
    virtual ~MassSpringArbitrary();

    // A single spring: the indices of the two particles it joins, its
    // spring constant, and its rest length.
    struct Spring
    {
        int particle0;
        int particle1;
        Real constant;
        Real length;
    };

    // Spring accessors.  SetSpring also records the spring as adjacent to
    // both of its endpoint particles.
    inline int GetNumSprings() const;
    void SetSpring(int index, Spring const& spring);
    inline Spring const& GetSpring(int index) const;

    // Acceleration F/m produced by an external force F acting on particle
    // i.  This base version returns zero; derive a class and override it to
    // model gravity, wind, friction, and the like.  Acceleration(...) calls
    // this function.
    virtual Vector<N, Real> ExternalAcceleration(int i, Real time,
        std::vector<Vector<N, Real>> const& position,
        std::vector<Vector<N, Real>> const& velocity);

protected:
    // Acceleration x" = F/m of particle i, used by the ODE solver.  The
    // position and velocity arrays need not be mPosition and mVelocity,
    // because the solver evaluates this at intermediate states.
    virtual Vector<N, Real> Acceleration(int i, Real time,
        std::vector<Vector<N, Real>> const& position,
        std::vector<Vector<N, Real>> const& velocity);

    // The springs, one slot per spring index.
    std::vector<Spring> mSpring;

    // mAdjacent[p] holds the indices into mSpring of the springs attached
    // to particle p.  These are spring indices, not indices of the
    // neighboring particles.
    std::vector<std::set<int>> mAdjacent;
};


// Nothing to release explicitly: every member cleans up after itself.
template <int N, typename Real>
MassSpringArbitrary<N, Real>::~MassSpringArbitrary() = default;

// Creates numParticles particles (through ParticleSystem) and numSprings
// spring slots.
//
// The count constructor of std::vector value-initializes every Spring
// aggregate, so the int members (and an arithmetic Real) start at zero.
//
// The former std::memset was redundant, and it was also undefined
// behavior in two cases:
//   - numSprings == 0, where &mSpring[0] indexes an empty vector;
//   - a Real type that is not trivially copyable.
//
// No spring is attached to any particle until SetSpring(...) is called
// for it.
template <int N, typename Real>
MassSpringArbitrary<N, Real>::MassSpringArbitrary(int numParticles,
    int numSprings, Real step)
    :
    ParticleSystem<N, Real>(numParticles, step),
    mSpring(numSprings),
    mAdjacent(numParticles)
{
}

// Number of spring slots in the spring array, converted to int.
template <int N, typename Real> inline
int MassSpringArbitrary<N, Real>::GetNumSprings() const
{
    auto const count = mSpring.size();
    return static_cast<int>(count);
}

// Stores spring 'index' and registers it as adjacent to both of its
// endpoint particles.
//
// Whatever spring previously occupied the slot is first unregistered from
// its old endpoints.  Without this, re-assigning a slot to different
// particles would leave stale adjacency entries, and Acceleration(...)
// would apply a phantom force to the old endpoints.
//
// Erasing an absent element from a std::set is a no-op, so this is also
// safe for slots that were never set (zero-initialized).
//
// No bounds checking is done: index must lie in [0, GetNumSprings()) and
// the particle indices must be valid indices into mAdjacent.
template <int N, typename Real>
void MassSpringArbitrary<N, Real>::SetSpring(int index, Spring const& spring)
{
    Spring const& previous = mSpring[index];
    mAdjacent[previous.particle0].erase(index);
    mAdjacent[previous.particle1].erase(index);

    mSpring[index] = spring;
    mAdjacent[spring.particle0].insert(index);
    mAdjacent[spring.particle1].insert(index);
}

// Read-only access to spring 'index'.  This is unchecked: index must lie
// in [0, GetNumSprings()).
template <int N, typename Real> inline
auto MassSpringArbitrary<N, Real>::GetSpring(int index) const
    -> Spring const&
{
    Spring const& result = mSpring[index];
    return result;
}

// Default external acceleration: the zero vector for every particle,
// regardless of time, positions, or velocities.  Override this in a derived
// class to add external forces.
template <int N, typename Real>
Vector<N, Real> MassSpringArbitrary<N, Real>::ExternalAcceleration(
    int /* i */, Real /* time */,
    std::vector<Vector<N, Real>> const& /* position */,
    std::vector<Vector<N, Real>> const& /* velocity */)
{
    Vector<N, Real> none = Vector<N, Real>::Zero();
    return none;
}

// Acceleration of particle i, used by the ODE solver.  It is the sum of:
//   - the acceleration returned by ExternalAcceleration(...);
//   - for each spring attached to i, the spring force multiplied by
//     particle i's inverse mass.
//
// The position/velocity arrays are the solver's state and are not
// necessarily mPosition/mVelocity: the RK4 solver in ParticleSystem
// evaluates this function at intermediate states.
template <int N, typename Real>
Vector<N, Real> MassSpringArbitrary<N, Real>::Acceleration(int i, Real time,
    std::vector<Vector<N, Real>> const& position,
    std::vector<Vector<N, Real>> const& velocity)
{
    Vector<N, Real> acceleration = ExternalAcceleration(i, time, position,
        velocity);

    for (int adj : mAdjacent[i])
    {
        // Process a spring connected to particle i.  'other' is the
        // spring's opposite endpoint.
        Spring const& spring = mSpring[adj];
        int const other = (i != spring.particle0 ?
            spring.particle0 : spring.particle1);
        Vector<N, Real> diff = position[other] - position[i];

        // When the endpoints coincide (including a spring whose two
        // endpoints are the same particle), the force has no direction.
        // The ratio below would also divide by zero and inject inf/NaN
        // into the state, so skip such springs.
        Real const currentLength = Length(diff);
        if (currentLength > (Real)0)
        {
            // Force on i is k*(|d| - L)*d/|d| = k*(1 - L/|d|)*d.  It pulls
            // i toward 'other' when the spring is stretched and pushes it
            // away when the spring is compressed.
            Real ratio = spring.length / currentLength;
            Vector<N, Real> force = spring.constant * ((Real)1 - ratio) * diff;
            acceleration += this->mInvMass[i] * force;
        }
    }

    return acceleration;
}


}
